from typing import NamedTuple, List, Any, Tuple

import torch


class TraceStep(NamedTuple):
    """
    A single step in a sequence: one (state, action, reward, next_state, done) transition.

    The states are tensors, not devices: CircularRolloutBuffer.add_step copies them straight
    into one row of its ``states`` / ``next_states`` storage tensors.
    """
    state: torch.Tensor  # Written into a row of the buffer's `states` tensor
    action: int  # Stored in a torch.long tensor by the buffer
    reward: float  # Stored in a torch.float32 tensor by the buffer
    next_state: torch.Tensor  # Written into a row of the buffer's `next_states` tensor
    done: bool  # Stored in the buffer's `terminal_flags` tensor


class RolloutSample(NamedTuple):
    """
    A batch of steps, stored field-by-field instead of step-by-step for efficiency.

    Row ``i`` of every field belongs to the same step, so all five tensors have the same
    length. The rows need not be consecutive steps of one episode: a batch drawn by random
    sampling has exactly this form.
    """
    states: torch.Tensor  # One row per sampled step
    actions: torch.Tensor  # Action of each step
    rewards: torch.Tensor  # Reward of each step
    next_states: torch.Tensor  # Same layout as `states`
    dones: torch.Tensor  # Terminal flag of each step


class CircularRolloutBuffer:
    """
    Fixed-capacity ring buffer of transitions with one sampling priority per slot.

    Every field of a transition lives in its own pre-allocated tensor whose first dimension is
    ``capacity``. Writes go to ``write_head``, which wraps back to index 0 after the last slot.
    When there is not enough free room, the oldest transitions (the ones directly behind the
    free region in front of ``write_head``) are zeroed out and marked invalid so that they can
    be overwritten.
    """

    def __init__(self, capacity: int, input_shape,
                 state_dtype: torch.dtype = torch.uint8, device: torch.device = "cpu"):
        """
        Arguments:
            capacity: Integer, Number of stored transitions
            input_shape: Shape of the preprocessed frame
            state_dtype: dtype of the tensors that store states and next_states
            device: Where to store everything
        """
        self.capacity = capacity
        self.input_shape = input_shape
        self.device = device
        self.state_dtype = state_dtype

        self.write_head = 0  # INVARIANT: Always 0 <= write_head < capacity
        self.num_filled = 0  # Number of occupied slots. INVARIANT: 0 <= num_filled <= capacity

        self.actions = torch.zeros((self.capacity,), dtype=torch.long, device=device)
        self.rewards = torch.zeros((self.capacity,), dtype=torch.float32, device=device)
        self.states = torch.zeros((self.capacity, *self.input_shape), dtype=state_dtype, device=device)
        self.next_states = torch.zeros((self.capacity, *self.input_shape), dtype=state_dtype, device=device)
        self.terminal_flags = torch.zeros((self.capacity,), dtype=torch.bool, device=device)
        # True for slots that hold a real transition; sample() never draws a slot that is False
        self.valid_samples = torch.zeros((self.capacity,), dtype=torch.bool, device=device)
        self.priorities = torch.zeros((self.capacity,), dtype=torch.float, device=device)

    def write_to_and_move_head_batch(self, mapping: List[Tuple[torch.Tensor, torch.Tensor]]):
        """
        Copy a batch of rows into the ring starting at write_head, wrapping around the end of
        the storage if necessary, then advance write_head by the batch length.

        Only the given destination tensors and write_head are touched: num_filled is not
        updated and nothing is cleared, so the caller has to do that bookkeeping itself.

        :param mapping: (destination storage tensor, source batch) pairs. The batch length is
            taken from the first dimension of the first source tensor.
        """
        len_to_write = mapping[0][1].shape[0]  # The first src tensor's first dimension
        write_end = self.write_head + len_to_write

        for dest_tensor, src_tensor in mapping:
            if write_end <= self.capacity:
                # Everything fits before we need to roll over to the beginning of the tape
                dest_tensor[self.write_head: write_end] = src_tensor
            else:
                # Need to split the write in two sections
                write_len_before_split = self.capacity - self.write_head
                write_len_after_split = len_to_write - write_len_before_split

                dest_tensor[self.write_head:] = src_tensor[:write_len_before_split]
                dest_tensor[:write_len_after_split] = src_tensor[write_len_before_split:]

        self.write_head = write_end % self.capacity

    def write_to_and_move_head(self, mapping: List[Tuple[torch.Tensor, Any]]):
        """
        Write one value per destination tensor into the slot at write_head, then advance
        write_head by one (wrapping to 0 at the end). Does not update num_filled.

        :param mapping: (destination storage tensor, value for the current slot) pairs
        """
        for dest_tensor, src_value in mapping:
            dest_tensor[self.write_head] = src_value

        self.write_head = (self.write_head + 1) % self.capacity

    def clear_between(self, start_idx, end_idx):
        """
        Zero out between two indices. Does not update num_filled. ASSUMPTION end_idx >= start_idx
        :param start_idx: Inclusive
        :param end_idx: Exclusive
        """
        self.actions[start_idx:end_idx] = 0
        self.rewards[start_idx:end_idx] = 0
        self.states[start_idx:end_idx] = 0
        self.next_states[start_idx:end_idx] = 0
        self.terminal_flags[start_idx:end_idx] = False
        self.valid_samples[start_idx:end_idx] = False
        self.priorities[start_idx:end_idx] = 0

    def ensure_space_and_zero_out(self, space_required: int):
        """
        Make sure there are space_required free slots in front of the write_head, zeroing out
        the oldest stored transitions if there are not enough, and update num_filled.

        Exactly as many slots as are missing get erased; erasing is not extended to the end
        of the episode that the last erased slot belongs to. Asking for more room than the
        buffer can hold simply empties it.
        """
        free_spaces_in_front_of_write_head = self.capacity - self.num_filled
        if free_spaces_in_front_of_write_head >= space_required:
            return

        # Never erase more than is actually stored
        num_to_erase = min(space_required - free_spaces_in_front_of_write_head, self.num_filled)
        if num_to_erase <= 0:
            return

        # The oldest transition sits right behind the free region in front of the write_head
        erase_head = (self.write_head + free_spaces_in_front_of_write_head) % self.capacity
        erase_end = erase_head + num_to_erase  # Exclusive, and not wrapped yet

        if erase_end <= self.capacity:  # No need to loop around
            self.clear_between(erase_head, erase_end)
        else:  # Loop around to the start
            self.clear_between(erase_head, self.capacity)
            self.clear_between(0, erase_end - self.capacity)

        self.num_filled -= num_to_erase

    def add_step(self, t_step: "TraceStep"):
        """
        Store a single transition at the write_head, evicting the oldest one if the buffer is full.

        The new slot is marked valid and gets the largest priority currently stored in the
        buffer, but never less than 1.0.
        """
        self.ensure_space_and_zero_out(1)

        max_priority = max(float(self.priorities.max()), 1.0)

        self.write_to_and_move_head([
            (self.actions, t_step.action),
            (self.rewards, t_step.reward),
            (self.states, t_step.state),
            (self.next_states, t_step.next_state),
            (self.terminal_flags, t_step.done),
            (self.valid_samples, True),
            (self.priorities, max_priority)
        ])

        self.num_filled += 1

    def add_episode(self, trace_steps: List["TraceStep"]):
        """
        Store all steps of an episode, in order.

        Room for the whole episode is freed up front, then the steps are written one by one
        through add_step.
        """
        total_buffer_space_required = len(trace_steps)
        self.ensure_space_and_zero_out(total_buffer_space_required)

        for t_step in trace_steps:
            self.add_step(t_step)

    def get_states_and_next_states(self, indices) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the stored states and next_states at the given indices, converted to float32."""
        return self.states[indices].float(), self.next_states[indices].float()

    def get_rollout_sample_from_indices(self, indices: torch.Tensor) -> "RolloutSample":
        """Gather the transitions stored at the given indices into one batched RolloutSample."""
        current_states, next_states = self.get_states_and_next_states(indices)

        return RolloutSample(
            states=current_states,
            actions=self.actions[indices],
            rewards=self.rewards[indices],
            next_states=next_states,
            dones=self.terminal_flags[indices],
        )

    def sample(self, batch_size: int, priority_scale=0.0) -> Tuple[
        "RolloutSample", torch.Tensor, torch.Tensor]:
        """
        Sample from the rollout buffer

        Slots are drawn with replacement, with probability proportional to
        priority ** priority_scale; invalid slots are never drawn. With priority_scale=0.0
        every valid slot is equally likely. There must be at least one valid slot.

        :param batch_size: Number of transitions to draw
        :param priority_scale: Exponent applied to the priorities before normalising
        :return: (the sampled batch, the indices of the drawn slots, one importance weight per
            drawn slot: 1 / probability, scaled so that the largest weight in the batch is 1)
        """
        scaled_priorities = torch.pow(self.priorities, priority_scale)

        # Masking with valid_samples gives empty slots a probability of exactly zero
        sample_probs = self.valid_samples.float() * scaled_priorities
        sample_probs = sample_probs / sample_probs.sum()

        indices = torch.multinomial(input=sample_probs, num_samples=batch_size, replacement=True)

        importance = 1 / sample_probs[indices]
        importance = importance / importance.max()

        return self.get_rollout_sample_from_indices(indices), indices, importance

    def num_filled_approx(self) -> int:
        """Return the number of slots currently counted as filled."""
        return self.num_filled

    def set_priorities(self, indices, errors, offset=0.1):
        """
        Set the priorities of the given slots to abs(error) + offset.

        :param indices: Slots to update
        :param errors: One error value per index
        :param offset: Added to every absolute error, so an updated priority is at least offset
        """
        self.priorities[indices] = errors.abs() + offset

    def reset_all_priorities(self):
        """Set the priority of every valid slot to 1 and of every other slot to 0."""
        self.priorities = self.valid_samples.float()

    def state_dict(self):
        """
        Return the write position, the fill count and all storage tensors as a dict.

        The tensors are the live storage itself, not copies.
        """
        return {
            "write_head": self.write_head,
            "num_filled": self.num_filled,
            "actions": self.actions,
            "rewards": self.rewards,
            "states": self.states,
            "next_states": self.next_states,
            "terminal_flags": self.terminal_flags,
            "valid_samples": self.valid_samples,
            "priorities": self.priorities
        }

    def load_state_dict(self, state_dict):
        """
        Restore the buffer from a dict produced by state_dict().

        The tensors are adopted as they are: they are neither copied nor moved to self.device,
        and capacity / input_shape are not changed to match them.
        """
        self.write_head = state_dict["write_head"]
        # Older versions of this class could save a num_filled of capacity + 1; clamp it so the
        # free-space arithmetic in ensure_space_and_zero_out never sees a negative amount
        self.num_filled = min(state_dict["num_filled"], self.capacity)
        self.actions = state_dict["actions"]
        self.rewards = state_dict["rewards"]
        self.states = state_dict["states"]
        self.next_states = state_dict["next_states"]
        self.terminal_flags = state_dict["terminal_flags"]
        self.valid_samples = state_dict["valid_samples"]
        self.priorities = state_dict["priorities"]
